import tensorflow as tf
import tensorflow.contrib.layers as layers


def inference(images, n_class, list_channels, type, keep_prob):
    """
    Build the forward graph: conv1 -> pool1/norm1 -> conv2 -> pool2/norm2 -> fc1 -> fc2 -> output layer.

    :param images: input image batch fed to tf.nn.conv2d (presumably NHWC, the conv2d
                   default -- TODO confirm against the input pipeline)
    :param n_class: number of classes; width of the output layer
    :param list_channels: layer widths. [0] input channels, [1] conv1 output channels,
                          [2] conv2 output channels, [3] fc1 width, [4] fc2 width
    :param type: "test" returns softmax probabilities; any other value returns raw logits.
                 (The name shadows the builtin but is kept so keyword callers keep working.)
    :param keep_prob: keep probability for the dropout after fc1 and fc2. Dropout is applied
                      regardless of `type`, so callers should pass 1.0 at evaluation time.
    :return: tensor of shape [batch, n_class]: logits, or probabilities when type == "test"
    :raises ValueError: if the flattened feature size after pool2 is not statically known
    """
    # 1: conv1 (5x5 kernel)
    conv1 = conv_layer(input_tensor=images,
                       filter_size=5,
                       input_channel=list_channels[0],
                       output_channel=list_channels[1])

    # 2: pool1 + norm1
    norm1 = pool_layer(conv1)

    # 3: conv2 (3x3 kernel)
    conv2 = conv_layer(input_tensor=norm1,
                       filter_size=3,
                       input_channel=list_channels[1],
                       output_channel=list_channels[2])

    # 4: pool2 + norm2
    norm2 = pool_layer(conv2)

    # 5: fully-connected 1 -- flatten each example's feature map into one vector first
    reshape = layers.flatten(norm2)
    flt = reshape.get_shape()[1].value
    if flt is None:
        # The fc1 weight matrix needs a concrete input width. This typically happens when
        # the spatial size of `images` is not known at graph-construction time.
        raise ValueError("inference: flattened feature size is unknown; "
                         "images must have static height, width and channels")
    fc1 = tf.nn.relu(fc_layer(input_tensor=reshape,
                              input_channel=flt,
                              stddev_w=0.005,
                              output_channel=list_channels[3]))
    fc1 = tf.nn.dropout(fc1, keep_prob)

    # 6: fully-connected 2
    fc2 = tf.nn.relu(fc_layer(input_tensor=fc1,
                              input_channel=list_channels[3],
                              stddev_w=0.005,
                              output_channel=list_channels[4]))
    fc2 = tf.nn.dropout(fc2, keep_prob)

    # 7: output layer (linear logits)
    train_logits = fc_layer(fc2, list_channels[4], 0.005, n_class)
    if type == "test":
        # Normalize logits into class probabilities for inference/testing.
        return tf.nn.softmax(train_logits)

    return train_logits


def conv_layer(input_tensor, filter_size, input_channel, output_channel):
    """
    Convolve with a square kernel (stride 1, SAME padding), add a bias, and apply ReLU.

    :param input_tensor: 4-D input tensor passed to tf.nn.conv2d
    :param filter_size: side length of the square kernel
    :param input_channel: number of input channels
    :param output_channel: number of output channels (filters)
    :return: ReLU-activated convolution output
    """
    # tf.nn.conv2d expects the filter as [filter_height, filter_width, in_channels, out_channels];
    # height and width are both filter_size because the kernel is square.
    kernel_shape = [filter_size, filter_size, input_channel, output_channel]
    kernel = get_weight(shape=kernel_shape, stddev=0.1)
    biases = get_bias(shape=[output_channel])

    unit_strides = [1, 1, 1, 1]
    convolved = tf.nn.conv2d(input=input_tensor, filter=kernel,
                             strides=unit_strides, padding="SAME")
    pre_activation = tf.nn.bias_add(convolved, biases)
    return tf.nn.relu(pre_activation)


def pool_layer(input_tensor):
    """
    Downsample with 3x3 max pooling (stride 2, SAME padding), then apply local response normalization.

    :param input_tensor: 4-D feature map (window/stride layout assumes NHWC, the tf.nn.max_pool default)
    :return: normalized, pooled feature map
    """
    window = [1, 3, 3, 1]
    step = [1, 2, 2, 1]
    pooled = tf.nn.max_pool(input_tensor, ksize=window, strides=step, padding="SAME")
    # Local response normalization across neighbouring channels.
    return tf.nn.lrn(pooled, depth_radius=4, bias=1.0, alpha=0.001 / 9.0, beta=0.75)


def fc_layer(input_tensor, input_channel, stddev_w, output_channel):
    """
    Affine (fully-connected) layer without activation: input_tensor @ W + b.

    :param input_tensor: 2-D input of width input_channel
    :param input_channel: number of input features
    :param stddev_w: stddev of the truncated-normal weight initializer
    :param output_channel: number of output features
    :return: pre-activation tensor of width output_channel
    """
    weights = get_weight(shape=[input_channel, output_channel], stddev=stddev_w)
    biases = get_bias(shape=[output_channel])
    return tf.matmul(input_tensor, weights) + biases


def losses(logits, labels):
    """
    Mean sparse softmax cross-entropy over the batch.

    :param logits: unnormalized class scores
    :param labels: integer class indices, one per example
    :return: scalar mean loss
    """
    per_example_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels, logits=logits)
    return tf.reduce_mean(per_example_loss)


def evaluation(logits, labels):
    """
    Top-1 accuracy of a batch.

    :param logits: class scores, one row per example
    :param labels: integer class indices, one per example
    :return: scalar float32 fraction of examples whose top-scoring class equals the label
    """
    correct = tf.nn.in_top_k(predictions=logits, targets=labels, k=1)
    # float32 instead of float16: half precision has only an 11-bit significand, so the
    # averaged accuracy is coarsely rounded, and a float16 scalar cannot be combined with
    # float32 tensors (e.g. the loss or summaries) without an explicit cast.
    correct = tf.cast(x=correct, dtype=tf.float32)
    accuracy = tf.reduce_mean(correct)
    return accuracy


def get_weight(shape, stddev):
    """
    Create a trainable float32 weight variable initialized from a truncated normal distribution.

    :param shape: shape of the variable
    :param stddev: standard deviation of the truncated normal
    :return: tf.Variable
    """
    initial_value = tf.truncated_normal(shape=shape, stddev=stddev, dtype=tf.float32)
    variable = tf.Variable(initial_value)
    return variable


def get_bias(shape):
    """
    Create a trainable float32 bias variable with every element initialized to 0.1.

    :param shape: shape of the variable
    :return: tf.Variable
    """
    return tf.Variable(tf.constant(value=0.1, dtype=tf.float32, shape=shape))
